from __future__ import annotations
import random 
import copy
import pandas as pd
from collections import defaultdict, Counter

from dsl_design.src.cluster import DPMM
from dsl_design.src.cluster2 import DPMM as DPMM2
from dsl_design.src.feature import Feature
from utils.distribution import N_Gaussian_Distribution
from utils.util import read_json, write_json

class Operation:
    def __init__(self, feature:Feature, operation_dsl_path):
        self.feature = feature
        self.name_mapping = {"input_flow_units": "Precond", "output_flow_units": "Postcond"}
        self.operation_dsl_path = operation_dsl_path
        self.operation_dsl_tree = {}
        # self.operation_dsl_tree_corpora = {}
        self.operation_dsl = {}
        self.curve = pd.DataFrame()

    def recursive_clustering(self, opcode, idx_list, value_list, hierarchy, iter_times=1000, alpha=0.1, regular=0.1):
        if value_list is None or len(value_list) == 0:
            return
        
        if hierarchy > 1:
            iter_times = 10
        
        result = DPMM.cluster(value_list, N_Gaussian_Distribution, len(value_list[0]), iter_times=iter_times, alpha=alpha, regular=regular)
        # result = DPMM2(X=value_list)
        if hierarchy == 1:
            self.curve[opcode] = [float(num) for num in result["log_likelihood_list"].split()]
        
        next_idx_lists, next_value_lists = self.feature.store_cluster_result(
            opcode=opcode,
            idx_list=idx_list,
            n_clusters=result["K"],
            labels=result["label"],
            hierarchy=hierarchy
        )

        if hierarchy < 3:
            for idx_list_next, value_list_next in zip(next_idx_lists, next_value_lists):
                self.recursive_clustering(opcode, idx_list_next, value_list_next, hierarchy + 1, iter_times, alpha, regular)
        
    def analyse(self, opcode):
        self.__hierarchy_tree_construction(opcode)
        self.__dsl_construction_3(opcode)
        self.__dsl_abstraction(opcode)
        self.__pattern_merge(opcode)
        # self.__dsl_tree_abstraction(opcode)
        # self.__dsl_construction(opcode)
        # self.__dsl_construction_2(opcode)

    # Build a hierarchical cluster tree from the clustering results.
    def __hierarchy_tree_construction(self, opcode):
        feature_data = self.feature.feature_data.get(opcode, [])
        
        tree = {}
        for sentence in feature_data:
            # label
            l1 = sentence.get("label-1")
            l2 = sentence.get("label-2")
            l3 = sentence.get("label-3")

            # feature
            h1 = sentence.get("hierarchy-1", {})
            h2 = sentence.get("hierarchy-2", {})
            h3 = sentence.get("hierarchy-3", {})

            if l1 and l1 not in tree:
                tree[l1] = {"pattern": {"Precond": [], "Postcond": []}, "examples": [], "child": []}
            if l2 and l2 not in tree[l1]:
                tree[l1] = {**tree[l1], l2: {"pattern": {"Device":[]}, "examples": [], "child": []}}
            if l3 and l3 not in tree[l1][l2]:
                tree[l1][l2] = {**tree[l1][l2], l3: {"pattern": {}, "examples": [], "child": []}}
            
            if l1:
                for key, value in h1.items():
                    tree[l1]["pattern"][self.name_mapping[key]].append(value)
                if not l2:
                    tree[l1]["examples"].append(sentence["sentence"])

            if l2 and h2:
                tree[l1]["child"].append(l2)
                tree[l1][l2]["pattern"]["Device"].append(h2)
                if not h3:
                    tree[l1][l2]["examples"].append(sentence["sentence"])

            # if l3:
            #     tree[l1][l2]["child"].append(l3)
            #     for key, value in h3.items():
            #         tree[l1][l2][l3]["pattern"].setdefault(key, []).append(value)
            #     tree[l1][l2][l3]["examples"].append(sentence["sentence"])

            if l3:
                tree[l1][l2]["child"].append(l3)
                for device, argkeys in h2.items():
                    for argkey in argkeys:
                        tree[l1][l2][l3]["pattern"].setdefault(device, {}).setdefault(argkey, []).extend(h3.get(argkey, []))
                tree[l1][l2][l3]["examples"].append(sentence["sentence"])
        
        tree = {label: tree[label] for label in sorted(tree)}
        self.operation_dsl_tree[opcode] = tree
        # self.operation_dsl_tree_corpora[opcode] = copy.deepcopy(tree)
        
        return tree

    # Abstract and merge the patterns at every level of the hierarchical cluster tree.
    def __dsl_tree_abstraction(self, opcode):
        tree = self.operation_dsl_tree[opcode]
        for _, feature1 in tree.items():
            feature1["pattern"] = self.__pattern_abstraction(feature1["pattern"], hierarchy=1)

            for label2, feature2 in feature1.items():
                if label2.isdigit():
                    feature2["pattern"] = self.__pattern_abstraction(feature2["pattern"], hierarchy=2)

                    for label3, feature3 in feature2.items():
                        if label3.isdigit():
                            feature3["pattern"] = self.__pattern_abstraction(feature3["pattern"], hierarchy=3)

    # Abstract and merge a single group of patterns.
    def __pattern_abstraction(self, pattern, hierarchy):
        abstract_pattern = {}
        if hierarchy == 1:
            for key in ["Precond", "Postcond"]:
                abstract_pattern[key] = []

                if not pattern[key]:
                    continue

                pattern_dict = defaultdict(int)
                
                # 统计同一类中不同的 pattern 计数
                for feature in pattern[key]:
                    pattern_tuple = tuple(sorted(Counter(feature).items()))
                    pattern_dict[pattern_tuple] += 1
                
                # 选择出现次数最多的 pattern
                most_common_pattern = max(pattern_dict, key=pattern_dict.get)

                arg_name = "SlotArg" if key == "Precond" else "EmitArg"
                for phase, num in most_common_pattern:
                    sub_pattern = {
                        f"{arg_name}Num": num,
                        arg_name: phase
                    }
                    abstract_pattern[key].append(sub_pattern)
        
        elif hierarchy == 2:
            pattern_dict = defaultdict(set)
            for feature in pattern["Device"]:
                for device, argkeys in feature.items():
                    pattern_dict[device].update(argkeys)
            sorted_pattern_dict = {
                key: sorted(value)
                for key, value in sorted(pattern_dict.items(), key=lambda item: len(item[1]), reverse=True)
            }
            abstract_pattern["Device"] = sorted_pattern_dict

        elif hierarchy == 3:
            for argkey, argvalues in pattern.items():
                value_range = sorted([value for value_list in argvalues for value in value_list])
                abstract_pattern[argkey] = value_range
        
        return abstract_pattern

    # Extract the pattern combinations of all leaf-node features in the hierarchical cluster tree and generate standard DSL instructions.
    def __dsl_construction(self, opcode):
        opcode_feature = []

        tree = copy.deepcopy(self.operation_dsl_tree[opcode])
        for _, feature1 in tree.items():
            pattern = {
                "Precond": [],
                "Postcond": [],
                "Execution": []
            }
            pattern1 = feature1["pattern"]
            pattern["Precond"] = pattern1["Precond"]
            pattern["Postcond"] = pattern1["Postcond"]
            example = feature1["examples"].copy()

            for label2, feature2 in feature1.items():
                if label2.isdigit():
                    pattern2 = feature2["pattern"]
                    for device, argkeys in pattern2["Device"].items():
                        sub_pattern = {"DeviceType": device, "Config": {}}
                        for argkey in argkeys:
                            sub_pattern["Config"][argkey] = []
                        pattern["Execution"].append(sub_pattern)
                    example = feature2["examples"].copy()

                    for label3, feature3 in feature2.items():
                        if label3.isdigit():
                            pattern3 = feature3["pattern"]
                            for argkey, value_range in pattern3.items():
                                for execution in pattern["Execution"]:
                                    if argkey in execution["Config"]:
                                        execution["Config"][argkey] = value_range
                            example = feature3["examples"].copy()
            
            opcode_feature.append({
                "pattern": pattern, 
                "example": example
            })
        
        self.operation_dsl[opcode] = opcode_feature

        return opcode_feature

    def __dsl_construction_2(self, opcode):
        def combine_patterns(tree, node, ancestor_pattern=None):
            if ancestor_pattern is None:
                ancestor_pattern = {
                    "Precond": [],
                    "Postcond": [],
                    "Execution": {
                        "Device": [],
                        "Config": {}
                    }
                }
            node_pattern = node.get("pattern", {})
            
            # 合并当前节点的 pattern
            if any(cond in node_pattern for cond in ["Precond", "Postcond"]):
                ancestor_pattern["Precond"].extend(node_pattern.get("Precond", []))
                ancestor_pattern["Postcond"].extend(node_pattern.get("Postcond", [])) 
            elif "Device" in node_pattern:
                for device, argkeys in node_pattern["Device"].items():
                    sub_pattern = {"DeviceType": device, "Argkeys": []}
                    sub_pattern["Argkeys"].extend(argkeys)
                    ancestor_pattern["Execution"]["Device"].append(sub_pattern)
            else:
                ancestor_pattern["Execution"]["Config"] = node_pattern
            
            # 如果有子节点，递归处理每个子节点
            leaf_patterns = []
            has_child = False
            for label, child_node in node.items():
                if label.isdigit():
                    has_child = True
                    # 递归
                    child_patterns = combine_patterns(tree, child_node, copy.deepcopy(ancestor_pattern))
                    leaf_patterns.extend(child_patterns)

            # 如果是叶子节点
            if not has_child:
                leaf_patterns.append({
                    "pattern": copy.deepcopy(ancestor_pattern),
                    "example": node.get("examples", [])
                })

            return leaf_patterns
        
        opcode_feature = []
        tree = copy.deepcopy(self.operation_dsl_tree[opcode])
        for label, feature in tree.items():
            if label.isdigit(): # 遍历第一层
                leaf_patterns = combine_patterns(tree, feature)
                opcode_feature.extend(leaf_patterns)

        self.operation_dsl[opcode] = opcode_feature

        return opcode_feature

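    # Generate the DSL from the cluster tree (sentence features are simply accumulated): merge the level-2 and level-3 patterns so that each level-1 cluster corresponds to one pattern.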
    def __dsl_construction_3(self, opcode):
        def combine_patterns(node):
            pattern = {
                    "Precond": [],
                    "Postcond": [],
                    "Execution": {
                        "Device": [],
                        "Config": {}
                }
            }
            examples = node.get("examples", []).copy()
            sub_pattern_1 = node.get("pattern", {})
            pattern["Precond"] = sub_pattern_1.get("Precond", [])
            pattern["Postcond"] = sub_pattern_1.get("Postcond", [])
            
            for label2, feature2 in node.items():
                if label2.isdigit():
                    sub_pattern_2 = feature2.get("pattern", {})
                    examples.extend(feature2.get("examples", []))
                    pattern["Execution"]["Device"].extend(sub_pattern_2.get("Device", []))

                    for label3, feature3 in feature2.items():
                        if label3.isdigit():
                            sub_pattern_3 = feature3.get("pattern", {})
                            examples.extend(feature3.get("examples", []))
                            for device, arg_dict in sub_pattern_3.items():
                                for argkey, argvalues in arg_dict.items():
                                    pattern["Execution"]["Config"].setdefault(device, {}).setdefault(argkey, []).extend(argvalues)
            
            res_pattern = {
                "pattern": pattern,
                "examples": examples if len(examples) < 10 else random.sample(examples, 10)
            }

            return res_pattern

        opcode_feature = []
        tree = copy.deepcopy(self.operation_dsl_tree[opcode])
        for label, feature in tree.items():
            if label.isdigit():
                pattern = combine_patterns(feature)
                opcode_feature.append(pattern)
        
        self.operation_dsl[opcode] = opcode_feature

        return opcode_feature

    def __dsl_pattern_abstraction(self, pattern):
        abstract_pattern = {
                    "Precond": {},
                    "Postcond": {},
                    "Execution": []
            }
        # 第一层特征
        for key in ["Precond", "Postcond"]:
            if not pattern[key]:
                continue
            pattern_dict = defaultdict(int)
            # 统计同一类中不同的 pattern 计数
            for feature in pattern[key]:
                pattern_tuple = tuple(sorted(Counter(feature).items()))
                pattern_dict[pattern_tuple] += 1
            # 选择出现次数最多的 pattern
            most_common_pattern = max(pattern_dict, key=pattern_dict.get)
            arg_name = "SlotArg" if key == "Precond" else "EmitArg"
            sub_pattern = {
                f"{arg_name}Num": 0,
                arg_name: []
            }
            for phase, num in most_common_pattern:
                for _ in range(num):
                    sub_pattern[f"{arg_name}Num"] += 1
                    sub_pattern[arg_name].append(phase)
            abstract_pattern[key] = sub_pattern
        # 第二层特征
        pattern_dict = defaultdict(set)
        for feature in pattern["Execution"]["Device"]:
            for device, argkeys in feature.items():
                pattern_dict[device].update(argkeys)
        sorted_pattern_dict = {
            key: sorted(value)
            for key, value in sorted(pattern_dict.items(), key=lambda item: len(item[1]), reverse=True)
        }
        for device, argkeys in sorted_pattern_dict.items():
            sub_pattern = {"DeviceType": device, "Config": {}}
            for argkey in argkeys:
                sub_pattern["Config"][argkey] = []
            abstract_pattern["Execution"].append(sub_pattern)
        # 第三层特征
        # abstract_pattern["Execution"]["Config"] = {argkey: self.__sorted_unique(argvalues) for argkey, argvalues in pattern["Execution"]["Config"].items()}
        for device, arg_dict in pattern["Execution"]["Config"].items():
            for execution in abstract_pattern["Execution"]:
                if execution["DeviceType"] == device:
                    for argkey, argvalues in arg_dict.items():
                        execution["Config"][argkey].extend(self.__sorted_unique(argvalues))
        
        return abstract_pattern
    
    def __dsl_abstraction(self, opcode):
        dsl = self.operation_dsl[opcode]
        for meta_pattern in dsl[:]:
            meta_pattern["pattern"] = self.__dsl_pattern_abstraction(meta_pattern["pattern"])
            if not meta_pattern["pattern"]["Execution"] or len(meta_pattern["examples"]) < 3:
                dsl.remove(meta_pattern)
            meta_pattern["pattern"] = {
                "Precond": meta_pattern["pattern"]["Precond"],
                "Execution": meta_pattern["pattern"]["Execution"],
                "Postcond": meta_pattern["pattern"]["Postcond"]
            }
        if not dsl:
            del self.operation_dsl[opcode]

    def __pattern_merge(self, opcode):
        if opcode not in self.operation_dsl:
            return
        
        def has_same_slot_emit(precond1, precond2, postcond1, postcond2):
            return (set(precond1.get("SlotArg", [])) == set(precond2.get("SlotArg", []))) and (set(postcond1.get("EmitArg", [])) == set(postcond2.get("EmitArg", [])))
        
        def has_same_device_type(exec1, exec2):
            device_types1 = [device["DeviceType"] for device in exec1]
            device_types2 = [device["DeviceType"] for device in exec2]
            return set(device_types1) == set(device_types2)
        
        def choose_pattern(item1, item2):
            return item1 if len(item1["examples"]) >= len(item2["examples"]) else item2

        data = self.operation_dsl[opcode]
        merged_data = []
        # 用于存储已经处理过的字典的索引
        processed_indices = set()
        # 遍历每一个数据字典
        for i, item in enumerate(data):
            if i in processed_indices:
                continue  # 如果这个字典已经合并过，跳过
            # 初始化合并对象
            merged_item = {
                "pattern": {
                    "Precond": item["pattern"]["Precond"],
                    "Execution": [],
                    "Postcond": item["pattern"]["Postcond"]
                },
                "examples": []
            }
            # 合并第一个字典的内容
            merged_item["pattern"]["Execution"].extend(item["pattern"]["Execution"])
            merged_item["examples"].extend(item["examples"])
            # 与后续未处理的数据进行比较和合并
            for j in range(i+1, len(data)):
                if j in processed_indices:
                    continue  # 如果这个字典已经处理过，跳过
                # 如果 precond 和 postcond 相同，则合并
                if (has_same_slot_emit(item["pattern"]["Precond"], data[j]["pattern"]["Precond"],
                                       item["pattern"]["Postcond"], data[j]["pattern"]["Postcond"]) or 
                    has_same_device_type(item["pattern"]["Execution"], data[j]["pattern"]["Execution"])):

                    chosen_pattern = choose_pattern(item, data[j])
                    merged_item["pattern"]["Precond"] = chosen_pattern["pattern"]["Precond"]
                    merged_item["pattern"]["Postcond"] = chosen_pattern["pattern"]["Postcond"]
                    # 合并 examples
                    merged_item["examples"] = chosen_pattern["examples"].copy()
                    # 合并 Execution
                    for new_device in data[j]["pattern"]["Execution"]:
                        # 检查是否存在相同的 DeviceType
                        for existing_device in merged_item["pattern"]["Execution"]:
                            if existing_device["DeviceType"] == new_device["DeviceType"]:
                                # 如果 DeviceType 相同，合并 Config
                                for key, value in new_device["Config"].items():
                                    if key in existing_device["Config"]:
                                        existing_device["Config"][key].extend(value)
                                    else:
                                        existing_device["Config"][key] = value
                                break
                        else:
                            # 如果没有相同的 DeviceType，添加新的 Device
                            merged_item["pattern"]["Execution"].append(new_device)

                    # 标记这个字典已经处理过
                    processed_indices.add(j)

            for device_dict in merged_item["pattern"]["Execution"]:
                device_dict["Config"] = {key: self.__sorted_unique(value_list) for key, value_list in device_dict["Config"].items()}
            # 如果 examples 超过 10 个，随机取 10 个
            if len(merged_item["examples"]) > 10:
                merged_item["examples"] = random.sample(merged_item["examples"], 10)
            # 将合并后的字典加入结果列表
            merged_data.append(merged_item)
        self.operation_dsl[opcode] = merged_data

    def __sorted_unique(self, lst):
        try:
            count = Counter(lst)
            return sorted(count.keys(), key=lambda x: count[x], reverse=True)
        except:
            return lst

    def dump_result(self):
        self.feature.dump_feature_data()
        self.dump_operation_dsl()
        self.dump_log()

    def dump_operation_dsl(self):
        base_path = self.operation_dsl_path.rsplit(".json", 1)[0]
        write_json(self.operation_dsl_path, self.operation_dsl)
        write_json(f"{base_path}_tree.json", self.operation_dsl_tree)
        # write_json(f"{base_path}_tree_corpora.json", self.operation_dsl_tree_corpora) 

    def dump_log(self):
        metadata = {"Number of opcode": len(self.operation_dsl)}
        count = {opcode: len(operations) for opcode, operations in self.operation_dsl.items()}
        metadata["Average number of pattern"] = round(sum(count.values()) / len(count), 1)
        metadata.update(sorted(count.items(), key=lambda item: item[1], reverse=True))
        base_path = self.operation_dsl_path.rsplit(".json", 1)[0]
        write_json(f"{base_path}_metadata.json", metadata)

        self.curve.to_csv(f"{base_path}_curve_3.csv", index=False)
        